{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch & MLflow quickstart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Welcome to PrimeHub!\n",
    "\n",
    "In this quickstart, we will perfome following actions to train, manage, and deploy the model: \n",
    "\n",
    "1. Train a neural network that classifies images with <a target=\"_blank\" href=\"https://www.mlflow.org/docs/latest/python_api/mlflow.pytorch.html#mlflow.pytorch.autolog\">MLflow autologging API</a> enabled.\n",
    "1. Register the trained model to PrimeHub <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-management\">Model Management</a>.\n",
    "1. Deploy the managed model on PrimeHub <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-deployment-feature\">Model Deployments</a>.\n",
    "1. Test deployed model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prerequisites\n",
    "1. Enable <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-management\">Model Management</a>.\n",
    "1. Enable <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-deployment-feature\">Model Deployments</a>.\n",
    "\n",
    "**Contact your admin if any prerequisite is not enabled yet.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Train a neural network that classifies images with <a target=\"_blank\" href=\"https://www.mlflow.org/docs/latest/python_api/mlflow.pytorch.html#mlflow.pytorch.autolog\">MLflow autologging API</a> enabled"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to use mlflow pytorch autolog, it needs to install the `mlflow` and `pytorch-lightning` first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: mlflow==1.18.0 in /opt/conda/lib/python3.7/site-packages (1.18.0)\n",
      "Requirement already satisfied: pytorch-lightning==1.3.0 in /opt/conda/lib/python3.7/site-packages (1.3.0)\n",
      "Requirement already satisfied: querystring-parser in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.2.4)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (20.9)\n",
      "Requirement already satisfied: prometheus-flask-exporter in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (0.18.2)\n",
      "Requirement already satisfied: gunicorn in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (20.1.0)\n",
      "Requirement already satisfied: pytz in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (2021.1)\n",
      "Requirement already satisfied: gitpython>=2.1.0 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (3.1.18)\n",
      "Requirement already satisfied: protobuf>=3.7.0 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (3.11.4)\n",
      "Requirement already satisfied: sqlalchemy in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.3.23)\n",
      "Requirement already satisfied: sqlparse>=0.3.1 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (0.4.1)\n",
      "Requirement already satisfied: click>=7.0 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (8.0.1)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.21.0)\n",
      "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.0.5)\n",
      "Requirement already satisfied: Flask in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (2.0.1)\n",
      "Requirement already satisfied: entrypoints in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (0.3)\n",
      "Requirement already satisfied: databricks-cli>=0.8.7 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (0.9.1)\n",
      "Requirement already satisfied: cloudpickle in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.4.1)\n",
      "Requirement already satisfied: alembic<=1.4.1 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (1.4.1)\n",
      "Requirement already satisfied: requests>=2.17.3 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (2.25.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (5.4.1)\n",
      "Requirement already satisfied: docker>=4.0.0 in /opt/conda/lib/python3.7/site-packages (from mlflow==1.18.0) (5.0.0)\n",
      "Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (2.6.0)\n",
      "Requirement already satisfied: torch>=1.4 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (1.8.0+cpu)\n",
      "Requirement already satisfied: future>=0.17.1 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (0.18.2)\n",
      "Requirement already satisfied: torchmetrics>=0.2.0 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (0.5.0)\n",
      "Requirement already satisfied: fsspec[http]>=2021.4.0 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (2021.6.1)\n",
      "Requirement already satisfied: pyDeprecate==0.3.0 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (0.3.0)\n",
      "Requirement already satisfied: tqdm>=4.41.0 in /opt/conda/lib/python3.7/site-packages (from pytorch-lightning==1.3.0) (4.61.1)\n",
      "Requirement already satisfied: python-dateutil in /opt/conda/lib/python3.7/site-packages (from alembic<=1.4.1->mlflow==1.18.0) (2.8.1)\n",
      "Requirement already satisfied: Mako in /opt/conda/lib/python3.7/site-packages (from alembic<=1.4.1->mlflow==1.18.0) (1.1.4)\n",
      "Requirement already satisfied: python-editor>=0.3 in /opt/conda/lib/python3.7/site-packages (from alembic<=1.4.1->mlflow==1.18.0) (1.0.4)\n",
      "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from click>=7.0->mlflow==1.18.0) (4.5.0)\n",
      "Requirement already satisfied: six>=1.10.0 in /opt/conda/lib/python3.7/site-packages (from databricks-cli>=0.8.7->mlflow==1.18.0) (1.16.0)\n",
      "Requirement already satisfied: configparser>=0.3.5 in /opt/conda/lib/python3.7/site-packages (from databricks-cli>=0.8.7->mlflow==1.18.0) (5.0.2)\n",
      "Requirement already satisfied: tabulate>=0.7.7 in /opt/conda/lib/python3.7/site-packages (from databricks-cli>=0.8.7->mlflow==1.18.0) (0.8.9)\n",
      "Requirement already satisfied: websocket-client>=0.32.0 in /opt/conda/lib/python3.7/site-packages (from docker>=4.0.0->mlflow==1.18.0) (0.57.0)\n",
      "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.7/site-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.0) (3.7.4.post0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.0 in /opt/conda/lib/python3.7/site-packages (from gitpython>=2.1.0->mlflow==1.18.0) (3.10.0.0)\n",
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /opt/conda/lib/python3.7/site-packages (from gitpython>=2.1.0->mlflow==1.18.0) (4.0.7)\n",
      "Requirement already satisfied: smmap<5,>=3.0.1 in /opt/conda/lib/python3.7/site-packages (from gitdb<5,>=4.0.1->gitpython>=2.1.0->mlflow==1.18.0) (3.0.5)\n",
      "Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from protobuf>=3.7.0->mlflow==1.18.0) (49.6.0.post20210108)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests>=2.17.3->mlflow==1.18.0) (4.0.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests>=2.17.3->mlflow==1.18.0) (2020.6.20)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests>=2.17.3->mlflow==1.18.0) (2.10)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests>=2.17.3->mlflow==1.18.0) (1.26.5)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (2.0.1)\n",
      "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.13.0)\n",
      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (1.8.0)\n",
      "Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (1.35.0)\n",
      "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.36.2)\n",
      "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (1.39.0)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (3.3.4)\n",
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.4.5)\n",
      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.6.1)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (4.2.2)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.2.8)\n",
      "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (4.7.2)\n",
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (1.3.0)\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (0.4.8)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.0) (3.1.1)\n",
      "Requirement already satisfied: async-timeout<4.0,>=3.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.0) (3.0.1)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.0) (1.6.3)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.7/site-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.0) (5.1.0)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.0) (21.2.0)\n",
      "Requirement already satisfied: Jinja2>=3.0 in /opt/conda/lib/python3.7/site-packages (from Flask->mlflow==1.18.0) (3.0.1)\n",
      "Requirement already satisfied: itsdangerous>=2.0 in /opt/conda/lib/python3.7/site-packages (from Flask->mlflow==1.18.0) (2.0.1)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.7/site-packages (from Jinja2>=3.0->Flask->mlflow==1.18.0) (2.0.1)\n",
      "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->click>=7.0->mlflow==1.18.0) (3.4.1)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->mlflow==1.18.0) (2.4.7)\n",
      "Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.7/site-packages (from prometheus-flask-exporter->mlflow==1.18.0) (0.11.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install mlflow==1.18.0 pytorch-lightning==1.3.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Firstly, let's import libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from datetime import datetime\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms\n",
    "import mlflow\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the experiment name so the execution of this notebook will be associated with `pytorch-quickstart` experiment. Also, enable the `MLflow autologging` to record parameters, metrics, and models automatically.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: 'pytorch-quickstart' does not exist. Creating a new experiment\n"
     ]
    }
   ],
   "source": [
    "mlflow.set_experiment(\"pytorch-quickstart\")\n",
    "mlflow.pytorch.autolog()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load and prepare the MNIST dataset. Convert the samples to tensor and normalize them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-08-24 04:15:18--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz\n",
      "Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14\n",
      "Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]\n",
      "--2021-08-24 04:15:19--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz\n",
      "Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: unspecified [application/x-gzip]\n",
      "Saving to: ‘MNIST.tar.gz’\n",
      "\n",
      "MNIST.tar.gz            [<=>                 ]  33.20M  5.77MB/s    in 7.8s    \n",
      "\n",
      "2021-08-24 04:15:28 (4.24 MB/s) - ‘MNIST.tar.gz’ saved [34813078]\n",
      "\n",
      "MNIST/\n",
      "MNIST/raw/\n",
      "MNIST/raw/train-labels-idx1-ubyte\n",
      "MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
      "MNIST/raw/t10k-labels-idx1-ubyte\n",
      "MNIST/raw/t10k-images-idx3-ubyte.gz\n",
      "MNIST/raw/train-images-idx3-ubyte\n",
      "MNIST/raw/train-labels-idx1-ubyte.gz\n",
      "MNIST/raw/t10k-images-idx3-ubyte\n",
      "MNIST/raw/train-images-idx3-ubyte.gz\n",
      "MNIST/processed/\n",
      "MNIST/processed/training.pt\n",
      "MNIST/processed/test.pt\n"
     ]
    }
   ],
   "source": [
    "!wget www.di.ens.fr/~lelarge/MNIST.tar.gz\n",
    "!tar -zxvf MNIST.tar.gz\n",
    "!rm MNIST.tar.gz\n",
    "\n",
    "transform=transforms.Compose([transforms.ToTensor(),\n",
    "                              transforms.Normalize((0.1307,), (0.3081,))])\n",
    "\n",
    "mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)\n",
    "mnist_train = DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
    "mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)\n",
    "mnist_test = DataLoader(mnist_test, batch_size=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set device on GPU if available, else CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.cuda.current_device()\n",
    "    print(torch.cuda.device(device))\n",
    "    print('Device Count:', torch.cuda.device_count())\n",
    "    print('Device Name: {}'.format(torch.cuda.get_device_name(device)))\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MLflow autologging` supports pytorch lightning. So, we need to build the model class based on pytorch lightning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PyTorchModel(pl.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # mnist images are (1, 28, 28) (channels, width, height)\n",
    "        self.layer_1 = nn.Linear(28 * 28, 128)\n",
    "        self.layer_2 = nn.Linear(128, 256)\n",
    "        self.layer_3 = nn.Linear(256, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        batch_size, channels, width, height = x.size()\n",
    "\n",
    "        # (b, 1, 28, 28) -> (b, 1*28*28)\n",
    "        x = x.view(batch_size, -1)\n",
    "        x = self.layer_1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.layer_2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.layer_3(x)\n",
    "\n",
    "        x = F.softmax(x, dim=1)\n",
    "        return x\n",
    "\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = F.nll_loss(logits, y)\n",
    "        return loss\n",
    "    \n",
    "    def test_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        _, predicted = torch.max(logits.data, 1)\n",
    "        total = y.size(0)\n",
    "        correct = (predicted == y).sum().item()\n",
    "        return {\n",
    "            'total': total,\n",
    "            'correct': correct\n",
    "        }\n",
    "        \n",
    "    def test_epoch_end(self, outputs):\n",
    "        total = np.array([x['total'] for x in outputs]).sum()\n",
    "        correct = np.array([x['correct'] for x in outputs]).sum()\n",
    "        self.log(\"Accuracy\", 100 * correct / total)\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "        return optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the PyTorch lightning trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n"
     ]
    }
   ],
   "source": [
    "model = PyTorchModel().to(device)\n",
    "trainer = pl.Trainer(max_epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model to minimize the loss. The mlflow function will log the model result automatically. In addition, we need to manually log the model class python file. And finally, we use `trainer.test` method checks the model performance in the testing dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name    | Type   | Params\n",
      "-----------------------------------\n",
      "0 | layer_1 | Linear | 100 K \n",
      "1 | layer_2 | Linear | 33.0 K\n",
      "2 | layer_3 | Linear | 2.6 K \n",
      "-----------------------------------\n",
      "136 K     Trainable params\n",
      "0         Non-trainable params\n",
      "136 K     Total params\n",
      "0.544     Total estimated model params size (MB)\n",
      "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b652d303af4c75bf4eb7922d5ff291",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  warnings.warn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b14f8cb47a1a422fa70924b7583c78a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'Accuracy': 94.77}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "with mlflow.start_run():\n",
    "    trainer.fit(model, mnist_train)\n",
    "    \n",
    "    model_class_file_content = \"\"\"\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "\n",
    "class PyTorchModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # mnist images are (1, 28, 28) (channels, width, height)\n",
    "        self.layer_1 = nn.Linear(28 * 28, 128)\n",
    "        self.layer_2 = nn.Linear(128, 256)\n",
    "        self.layer_3 = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size, channels, width, height = x.size()\n",
    "\n",
    "        # (b, 1, 28, 28) -> (b, 1*28*28)\n",
    "        x = x.view(batch_size, -1)\n",
    "        x = self.layer_1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.layer_2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.layer_3(x)\n",
    "\n",
    "        x = F.softmax(x, dim=1)\n",
    "        return x\n",
    "\"\"\"\n",
    "    model_class_file = open(\"ModelClass.py\", \"w\")\n",
    "    model_class_file.write(model_class_file_content)\n",
    "    model_class_file.close()\n",
    "    mlflow.log_artifact(\"ModelClass.py\", artifact_path=\"model/data\")\n",
    "    os.remove(\"ModelClass.py\")\n",
    "    \n",
    "    trainer.test(model, mnist_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Register the trained model to PrimeHub <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-management\">Model Management</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we've trained the model and it should be automatically exported to MLflow server.\n",
    "Next, back to PrimeHub and select `Models`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-1-menu.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the models page, click on the `MLflow UI` button."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-2-model-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the `MLflow UI`, switch to `Experiments` tab."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-3-mlflow-ui.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select our specified experiment name `pytorch-quickstart`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-4-experiment.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It shows all runs in `pytorch-quickstart` experiment, now click on our executed run."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-5-run-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both parameters, metrics, and artifacts can be found in this page."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-6-run-info.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scroll down to the `Artifacts` section. Click on the exported `model` and `Register Model` button."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-7-artifacts.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the model selector, choose the `Create New Model`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-8-register-model.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fill in the [Model Name] field with [pytorch-quickstart] and click on `Register` button."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-9-register-model.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see our model is successfully registered as version `1`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-10-registered.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Back and refresh the models page in the PrimeHub UI, now we can see our model `pytorch-quickstart` is managed in model list."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-11-model-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Deploy the managed model on PrimeHub <a target=\"_blank\" href=\"https://docs.primehub.io/docs/model-deployment-feature\">Model Deployments</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, click on our managed model name `pytorch-quickstart`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/2-11-model-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It shows all versions of `pytorch-quickstart` model, let's click on the `Deploy` button of `Version 1`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-1-version-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the deployment selector, choose the `Create new deployment` and click on `OK` button."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-2-new-deployment.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be directed to `Create Deployment` page. And the `Model URI` field will be auto fill-in with registered model scheme (`models:/pytorch-quickstart/1`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-3-model-uri.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next,\n",
    "1. Fill in the [Deployment Name] field with [pytorch-quickstart].\n",
    "1. Select the [Model Image] field with [PyTorch server]; this is a pre-packaged model server image that can serve PyTorch model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-4-name-model-image.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose the Instance Type, the minimal requirements in this quickstart is `CPU: 0.5 / Memory: 1 G / GPU: 0`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch/3-5-resource.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, click on `Deploy` button.\n",
    "\n",
    "Our model is deploying, let's click on the `pytorch-quickstart` deployment cell."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch/3-6-deployment-list.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the deployment detail page, we can see the `Status` is `Deploying`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-5-deploying.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wait for a while and our model is `Deployed` now!\n",
    "\n",
    "To test our deployment, let's copy the value of `Endpoint` (`https://.../predictions`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/pytorch_mlflow/3-6-deployed.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Test deployed model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before testing, let's display the test image. It is a 28x28 grayscale image represented as numpy.ndarray."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM4ElEQVR4nO3db6xU9Z3H8c9nWZoY6QNQce9alC7xgc3GgCIxQTfXkDYsPsBGuikPGjZpvH2Apo0NWeM+wIeN2bZZn5DcRlO6YW1IqEqMcSHYSBq18WJQLr0BkbBwyxVsMCmYGES/++AeN1ecc2acMzNn4Pt+JZOZOd85Z74Z7odz5vyZnyNCAK5+f9N0AwAGg7ADSRB2IAnCDiRB2IEk/naQb2abXf9An0WEW02vtWa3vdb2EdvHbD9WZ1kA+svdHme3PU/SUUnfljQt6U1JGyPiTxXzsGYH+qwfa/ZVko5FxPGIuCjpt5LW11gegD6qE/abJJ2a83y6mPYFtsdsT9ieqPFeAGqqs4Ou1abClzbTI2Jc0rjEZjzQpDpr9mlJS+Y8/4ak0/XaAdAvdcL+pqRbbX/T9tckfV/S7t60BaDXut6Mj4hLth+W9D+S5kl6JiIO96wzAD3V9aG3rt6M7+xA3/XlpBoAVw7CDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJdj88uSbZPSDov6VNJlyJiZS+aAtB7tcJeuC8i/tKD5QDoIzbjgSTqhj0k7bF9wPZYqxfYHrM9YXui5nsBqMER0f3M9t9HxGnbiyXtlfRIROyveH33bwagIxHhVtNrrdkj4nRxf1bSc5JW1VkegP7pOuy2r7X99c8fS/qOpMleNQagt+rsjb9R0nO2P1/Of0fEyz3pCkDP1frO/pXfjO/sQN/15Ts7gCsHYQeSIOxAEoQdSIKwA0n04kKYFDZs2FBae+ihhyrnPX36dGX9448/rqzv2LGjsv7++++X1o4dO1Y5L/JgzQ4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXDVW4eOHz9eWlu6dOngGmnh/PnzpbXDhw8PsJPhMj09XVp78sknK+edmLhyf0WNq96A5Ag7kARhB5Ig7EAShB1IgrADSRB2IAmuZ+9Q1TXrt99+e+W8U1NTlfXbbrutsn7HHXdU1kdHR0trd999d+W8p06dqqwvWbKksl7HpUuXKusffPBBZX1kZKTr9z558mRl/Uo+zl6GNTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH17FeBhQsXltaWL19eOe+BAwcq63fddVdXPXWi3e/lHz16tLLe7vyFRYsWldY2b95cOe+2bdsq68Os6+vZbT9j+6ztyTnTFtnea/vd4r78rw3AUOhkM/7XktZeNu0xSfsi4lZJ+4rnAIZY27BHxH5J5y6bvF7S9uLxdkkP9LgvAD3W7bnxN0bEjCRFxIztxWUvtD0maazL9wHQI32/ECYixiWNS+ygA5rU7aG3M7ZHJKm4P9u7lgD0Q7dh3y1pU/F4k6QXetMOgH5pe5zd9rOSRiVdL+mMpK2Snpe0U9LNkk5K+l5EXL4Tr9Wy2IxHxx588MHK+s6dOyvrk5OTpbX77ruvct5z59r+OQ+tsuPsbb+zR8TGktKaWh0BGChOlwWSIOxAEoQdSIKwA0kQdiAJLnFFYxYvLj3LWpJ06NChWvNv2LChtLZr167Kea9kDNkMJEfYgSQIO5AEYQeSIOxAEoQdSIKwA0kwZDMa0+7nnG+44YbK+ocfflhZP3LkyFfu6WrGmh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuB6dvTV6tWrS2uvvPJK5bzz58+vrI+OjlbW9+/fX1m/
WnE9O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4kwfXs6Kt169aV1todR9+3b19l/fXXX++qp6zartltP2P7rO3JOdOesP1n2weLW/m/KICh0Mlm/K8lrW0x/ZcRsby4vdTbtgD0WtuwR8R+SecG0AuAPqqzg+5h2+8Um/kLy15ke8z2hO2JGu8FoKZuw75N0jJJyyXNSPp52QsjYjwiVkbEyi7fC0APdBX2iDgTEZ9GxGeSfiVpVW/bAtBrXYXd9sicp9+VNFn2WgDDoe1xdtvPShqVdL3taUlbJY3aXi4pJJ2Q9KM+9oghds0111TW165tdSBn1sWLFyvn3bp1a2X9k08+qazji9qGPSI2tpj8dB96AdBHnC4LJEHYgSQIO5AEYQeSIOxAElziilq2bNlSWV+xYkVp7eWXX66c97XXXuuqJ7TGmh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmDIZlS6//77K+vPP/98Zf2jjz4qrVVd/ipJb7zxRmUdrTFkM5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4kwfXsyV133XWV9aeeeqqyPm/evMr6Sy+Vj/nJcfTBYs0OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lwPftVrt1x8HbHuu+8887K+nvvvVdZr7pmvd286E7X17PbXmL797anbB+2/eNi+iLbe22/W9wv7HXTAHqnk834S5J+GhG3Sbpb0mbb35L0mKR9EXGrpH3FcwBDqm3YI2ImIt4qHp+XNCXpJknrJW0vXrZd0gP9ahJAfV/p3HjbSyWtkPRHSTdGxIw0+x+C7cUl84xJGqvXJoC6Og677QWSdkn6SUT81W65D+BLImJc0nixDHbQAQ3p6NCb7fmaDfqOiPhdMfmM7ZGiPiLpbH9aBNALbdfsnl2FPy1pKiJ+Mae0W9ImST8r7l/oS4eoZdmyZZX1dofW2nn00Ucr6xxeGx6dbMavlvQDSYdsHyymPa7ZkO+0/UNJJyV9rz8tAuiFtmGPiD9IKvuCvqa37QDoF06XBZIg7EAShB1IgrADSRB2IAl+SvoqcMstt5TW9uzZU2vZW7Zsqay/+OKLtZaPwWHNDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJz9KjA2Vv6rXzfffHOtZb/66quV9UH+FDnqYc0OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lwnP0KcM8991TWH3nkkQF1gisZa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKKT8dmXSPqNpL+T9Jmk8Yj4T9tPSHpI0gfFSx+PiJf61Whm9957b2V9wYIFXS+73fjpFy5c6HrZGC6dnFRzSdJPI+It21+XdMD23qL2y4j4j/61B6BXOhmffUbSTPH4vO0pSTf1uzEAvfWVvrPbXipphaQ/FpMetv2O7WdsLyyZZ8z2hO2JWp0CqKXjsNteIGmXpJ9ExF8lbZO0TNJyza75f95qvogYj4iVEbGyB/0C6FJHYbc9X7NB3xERv5OkiDgTEZ9GxGeSfiVpVf/aBFBX27DbtqSnJU1FxC/mTB+Z87LvSprsfXsAeqWTvfGrJf1A0iHbB4tpj0vaaHu5pJB0QtKP+tIhann77bcr62vWrKmsnzt3rpftoEGd7I3/gyS3KHFMHbiCcAYdkARhB5Ig7EAShB1IgrADSRB2IAkPcshd24zvC/RZRLQ6VM6aHciCsANJEHYgCcIOJEHYgSQIO5AEYQeSGPSQzX+R9L9znl9fTBtGw9rbsPYl0Vu3etnbLWWFgZ5U86U3tyeG9bfphrW3Ye1LorduDao3NuOBJAg7kETTYR9v+P2rDGtvw9qXRG/dGkhvjX5nBzA4Ta/ZAQwIYQeSaCTsttfaPmL7mO3HmuihjO0Ttg/ZPtj0+HTFGHpnbU/OmbbI9l7b7xb3LcfYa6i3J2z/ufjsDtpe11BvS2z/3vaU7cO2f1xMb/Szq+hrIJ/bwL+z254n6aikb0ualvSmpI0R8aeBNlLC
9glJKyOi8RMwbP+TpAuSfhMR/1hMe1LSuYj4WfEf5cKI+Lch6e0JSReaHsa7GK1oZO4w45IekPSvavCzq+jrXzSAz62JNfsqScci4nhEXJT0W0nrG+hj6EXEfkmXD8myXtL24vF2zf6xDFxJb0MhImYi4q3i8XlJnw8z3uhnV9HXQDQR9psknZrzfFrDNd57SNpj+4DtsaabaeHGiJiRZv94JC1uuJ/LtR3Ge5AuG2Z8aD67boY/r6uJsLf6faxhOv63OiLukPTPkjYXm6voTEfDeA9Ki2HGh0K3w5/X1UTYpyUtmfP8G5JON9BHSxFxurg/K+k5Dd9Q1Gc+H0G3uD/bcD//b5iG8W41zLiG4LNrcvjzJsL+pqRbbX/T9tckfV/S7gb6+BLb1xY7TmT7Wknf0fANRb1b0qbi8SZJLzTYyxcMyzDeZcOMq+HPrvHhzyNi4DdJ6zS7R/49Sf/eRA8lff2DpLeL2+Gme5P0rGY36z7R7BbRDyVdJ2mfpHeL+0VD1Nt/STok6R3NBmukod7u0exXw3ckHSxu65r+7Cr6GsjnxumyQBKcQQckQdiBJAg7kARhB5Ig7EAShB1IgrADSfwfrLwRQMBWyxMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "array = np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.32941176470588235, 0.7254901960784313, 0.6235294117647059, 0.592156862745098, 0.23529411764705882, 0.1411764705882353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8705882352941177, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.9450980392156862, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.6666666666666666, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2627450980392157, 0.4470588235294118, 0.2823529411764706, 0.4470588235294118, 0.6392156862745098, 0.8901960784313725, 0.996078431372549, 0.8823529411764706, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.9803921568627451, 0.8980392156862745, 0.996078431372549, 0.996078431372549, 0.5490196078431373, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06666666666666667, 0.25882352941176473, 0.054901960784313725, 0.2627450980392157, 0.2627450980392157, 0.2627450980392157, 0.23137254901960785, 0.08235294117647059, 0.9254901960784314, 0.996078431372549, 0.41568627450980394, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3254901960784314, 0.9921568627450981, 0.8196078431372549, 0.07058823529411765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08627450980392157, 0.9137254901960784, 1.0, 0.3254901960784314, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5058823529411764, 0.996078431372549, 0.9333333333333333, 0.17254901960784313, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23137254901960785, 0.9764705882352941, 0.996078431372549, 0.24313725490196078, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215686274509804, 0.996078431372549, 0.7333333333333333, 0.0196078431372549, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03529411764705882, 0.803921568627451, 0.9725490196078431, 0.22745098039215686, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49411764705882355, 0.996078431372549, 0.7137254901960784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.29411764705882354, 0.984313725490196, 0.9411764705882353, 0.2235294117647059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07450980392156863, 
0.8666666666666667, 0.996078431372549, 0.6509803921568628, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011764705882352941, 0.796078431372549, 0.996078431372549, 0.8588235294117647, 0.13725490196078433, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14901960784313725, 0.996078431372549, 0.996078431372549, 0.30196078431372547, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.12156862745098039, 0.8784313725490196, 0.996078431372549, 0.45098039215686275, 0.00392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215686274509804, 0.996078431372549, 0.996078431372549, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23921568627450981, 0.9490196078431372, 0.996078431372549, 0.996078431372549, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4745098039215686, 0.996078431372549, 0.996078431372549, 0.8588235294117647, 0.1568627450980392, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4745098039215686, 0.996078431372549, 0.8117647058823529, 0.07058823529411765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])\n",
    "plt.imshow(array, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then replace `<MODEL_DEPLOYMENT_ENDPOINT>` in the cell below with the endpoint URL, `https://.../predictions` (copy the Endpoint value from the deployment detail page)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env ENDPOINT=<MODEL_DEPLOYMENT_ENDPOINT>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the cell below to send a request to the deployed model endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "response=!curl -X POST $ENDPOINT \\\n",
    "    -H 'Content-Type: application/json' \\\n",
    "    -d '{ \"data\": {\"ndarray\": [[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.32941176470588235, 0.7254901960784313, 0.6235294117647059, 0.592156862745098, 0.23529411764705882, 0.1411764705882353, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8705882352941177, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.9450980392156862, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.7764705882352941, 0.6666666666666666, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2627450980392157, 0.4470588235294118, 0.2823529411764706, 0.4470588235294118, 0.6392156862745098, 0.8901960784313725, 0.996078431372549, 0.8823529411764706, 0.996078431372549, 0.996078431372549, 0.996078431372549, 0.9803921568627451, 0.8980392156862745, 0.996078431372549, 0.996078431372549, 0.5490196078431373, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06666666666666667, 0.25882352941176473, 0.054901960784313725, 0.2627450980392157, 0.2627450980392157, 0.2627450980392157, 0.23137254901960785, 0.08235294117647059, 0.9254901960784314, 0.996078431372549, 0.41568627450980394, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3254901960784314, 0.9921568627450981, 0.8196078431372549, 0.07058823529411765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08627450980392157, 0.9137254901960784, 1.0, 0.3254901960784314, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5058823529411764, 0.996078431372549, 0.9333333333333333, 0.17254901960784313, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23137254901960785, 0.9764705882352941, 0.996078431372549, 0.24313725490196078, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215686274509804, 0.996078431372549, 0.7333333333333333, 0.0196078431372549, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03529411764705882, 0.803921568627451, 0.9725490196078431, 0.22745098039215686, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49411764705882355, 0.996078431372549, 0.7137254901960784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.29411764705882354, 0.984313725490196, 0.9411764705882353, 0.2235294117647059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.07450980392156863, 0.8666666666666667, 0.996078431372549, 0.6509803921568628, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011764705882352941, 0.796078431372549, 0.996078431372549, 0.8588235294117647, 0.13725490196078433, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14901960784313725, 0.996078431372549, 0.996078431372549, 0.30196078431372547, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.12156862745098039, 0.8784313725490196, 0.996078431372549, 0.45098039215686275, 0.00392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215686274509804, 0.996078431372549, 0.996078431372549, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23921568627450981, 0.9490196078431372, 0.996078431372549, 0.996078431372549, 0.20392156862745098, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4745098039215686, 0.996078431372549, 0.996078431372549, 0.8588235294117647, 0.1568627450980392, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4745098039215686, 0.996078431372549, 0.8117647058823529, 0.07058823529411765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]] } }'"
   ]
  },
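  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The request body in the curl cell above appears to be the same 28x28 `array` plotted earlier, wrapped in two extra list levels so that its shape is `(1, 1, 28, 28)`. As a minimal sketch that is not part of the original quickstart, the cell below builds that body from `array` instead of pasting the literal, and sends it with the Python standard library. The helper name `build_payload` and the variable `python_result` are hypothetical. The cell assumes that `array` is still defined, that `ENDPOINT` was set by the `%env` cell above, and that the deployment accepts the same `{\"data\": {\"ndarray\": ...}}` body as the curl request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import urllib.request\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def build_payload(image):\n",
    "    # Hypothetical helper: wrap a 28x28 image as a (1, 1, 28, 28) batch.\n",
    "    batch = np.asarray(image, dtype=float).reshape(1, 1, 28, 28)\n",
    "    return {'data': {'ndarray': batch.tolist()}}\n",
    "\n",
    "\n",
    "endpoint = os.environ.get('ENDPOINT', '')\n",
    "if endpoint.startswith('http'):\n",
    "    # Assumption: the endpoint accepts the same JSON body as the curl cell above.\n",
    "    request = urllib.request.Request(\n",
    "        endpoint,\n",
    "        data=json.dumps(build_payload(array)).encode('utf-8'),\n",
    "        headers={'Content-Type': 'application/json'},\n",
    "        method='POST',\n",
    "    )\n",
    "    with urllib.request.urlopen(request) as reply:\n",
    "        python_result = json.loads(reply.read().decode('utf-8'))\n",
    "else:\n",
    "    print('Set ENDPOINT to the deployment URL before running this cell.')"
   ]
  },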
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following cells, we parse the model response and find the class with the highest predicted probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = json.loads(response[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print the response data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data': {'names': [],\n",
       "  'ndarray': [[0.00039837765507400036,\n",
       "    1.8940910422315937e-06,\n",
       "    0.027976594865322113,\n",
       "    0.0040565794333815575,\n",
       "    9.291115077303402e-08,\n",
       "    0.0001759410952217877,\n",
       "    2.5345872245452483e-07,\n",
       "    0.9668492078781128,\n",
       "    0.0003237354685552418,\n",
       "    0.0002173809625674039]]},\n",
       " 'meta': {'requestPath': {'model': 'infuseai/pytorch-prepackaged:v0.2.0'}}}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model predicts the number is 7\n"
     ]
    }
   ],
   "source": [
    "print(f\"The model predicts the number is {np.argmax(result['data']['ndarray'][0])}\")"
   ]
  },
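  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.argmax` reports only the winning class. To see how the remaining scores compare, the sketch below ranks the classes by score. The helper name `top_k` is hypothetical, and the cell assumes the response layout printed above: a single row of ten scores under `result['data']['ndarray']`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def top_k(scores, k=3):\n",
    "    # Hypothetical helper: return the k (class_index, score) pairs with the highest scores.\n",
    "    scores = np.asarray(scores, dtype=float)\n",
    "    order = np.argsort(scores)[::-1][:k]\n",
    "    return [(int(i), float(scores[i])) for i in order]\n",
    "\n",
    "\n",
    "for digit, score in top_k(result['data']['ndarray'][0]):\n",
    "    print(f'digit {digit}: {score:.4f}')"
   ]
  },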
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations!\n",
    "\n",
    "We have versioned our trained model and deployed it as an endpoint service that can respond to requests anytime, from anywhere!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "get_started.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
